# -*- coding: utf-8 -*-
"""
Created on Wed Mar 29 09:47:36 2023

@author: lv
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding (Vaswani et al., 2017).

    Builds a ``(max_seq_len, d_model)`` table once at construction time:

        pos_enc[pos, 2k]   = sin(pos / 10000**(2k / d_model))
        pos_enc[pos, 2k+1] = cos(pos / 10000**(2k / d_model))

    and adds the first ``seq_len`` rows of it to the input on every
    forward pass.
    """

    def __init__(self, max_seq_len, d_model):
        """
        :param max_seq_len: maximum sequence length the table supports
        :param d_model: embedding dimension of the inputs
        """
        super(PositionalEncoding, self).__init__()
        self.max_seq_len = max_seq_len
        self.d_model = d_model

        # Build the encoding table in one vectorized pass instead of a
        # Python double loop. The original loop used exponent
        # (2*i)/d_model with i already even (squaring the intended
        # frequency decay) and indexed i+1 unconditionally, which raised
        # IndexError for odd d_model; both are fixed here.
        pos_enc = torch.zeros(max_seq_len, d_model)
        position = torch.arange(max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_seq_len, ceil(d_model / 2))
        pos_enc[:, 0::2] = torch.sin(angles)
        # For odd d_model the cosine half has one column fewer.
        pos_enc[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Register as a buffer: saved with the state dict and moved along
        # with the module's device, but not a trainable parameter.
        self.register_buffer('pos_enc', pos_enc)

    def forward(self, x):
        """
        :param x: shape=(batch_size, seq_len, d_model)
        :return: shape=(batch_size, seq_len, d_model)
        """
        seq_len = x.size(1)

        # Slice the table to the actual sequence length (the original added
        # the full max_seq_len table, which failed whenever
        # seq_len < max_seq_len) and let broadcasting handle the batch
        # dimension instead of materializing a per-batch copy with repeat().
        # The buffer already lives on the module's device, so the previous
        # discarded `self.pos_enc.to(x.device)` call was a no-op and is gone.
        return x + self.pos_enc[:seq_len].unsqueeze(0)


class SelfAttentionEncoder(nn.Module):
    """Fully-connected encoder mapping a padded batch of sequences to a
    ``(num_digits, 10)`` matrix of softmax scores — one 10-way distribution
    per digit position.

    NOTE(review): ``forward`` flattens the entire batch into a single vector
    before ``fc1``, whose input size is fixed at
    ``input_max_num_sequences * max_seq_length * num_digits`` — so despite
    the ``>=`` assertion, the forward pass only works when
    ``batch_size == input_max_num_sequences``. Confirm against the caller.
    """

    def __init__(
        self,
        max_seq_length,
        input_max_num_sequences,
        hidden_dims=(10240, 1024, 5120),  # tuple: avoid mutable default arg
        num_digits=6
    ):
        """
        :param max_seq_length: padded length of each input sequence
        :param input_max_num_sequences: number of sequences per (flattened) batch
        :param hidden_dims: widths of the two dense layers (index 0 and 1 used;
            index 2 is currently unused — kept for backward compatibility)
        :param num_digits: number of output digits (rows of the final softmax)
        """
        super(SelfAttentionEncoder, self).__init__()

        # Maximum sequence length
        self.max_seq_length = max_seq_length

        # Maximum number of sequences in one input batch
        self.input_max_num_sequences = input_max_num_sequences

        # Positional encoding layer (table sized max_seq_length x num_digits)
        self.positional_encoding = PositionalEncoding(max_seq_length, num_digits)

        # Dense layer 1: position-indexed token-relation features.
        # Consumes the whole flattened batch, hence the fixed input size.
        self.fc1 = nn.Linear(
            input_max_num_sequences * max_seq_length * num_digits,
            hidden_dims[0],
        )

        # Dense layer 2: higher-order phrase-relation features
        self.fc2 = nn.Linear(
            hidden_dims[0],
            hidden_dims[1],
        )

        # Attention head: one logit per (digit, class) pair
        self.attention = nn.Linear(hidden_dims[1], num_digits*10, bias=False)

        self.num_digits = num_digits

        # NOTE(review): 0.8 is an unusually aggressive dropout rate for the
        # first layer — confirm it is intentional and not a typo for 0.08/0.2.
        self.dropout1 = nn.Dropout(0.8)
        self.dropout2 = nn.Dropout(0.5)
        self.dropout3 = nn.Dropout(0.25)


    def forward(self, input_ids):
        """
        :param input_ids: shape=(batch_size, seq_length, embedded_size)
        :return: shape=(num_digits, 10) softmax scores, one row per digit
        """
        # Shape of the incoming batch
        batch_size, seq_length , embedded_size = input_ids.size()

        # NOTE(review): `assert` is stripped under `python -O`; a
        # ValueError would be safer. Also see the class docstring — fc1
        # effectively requires equality here, not just >=.
        assert self.input_max_num_sequences >= batch_size

        # Zero-pad up to max_seq_length.
        # NOTE(review): F.pad with a 2-tuple pads the LAST dimension
        # (embedded_size), not the sequence dimension that the variable
        # names and the original comment suggest — confirm which axis was
        # intended before relying on this.
        input_ids = F.pad(input_ids, (0, self.max_seq_length - seq_length), "constant", 0)

        # Apply positional encoding (PositionalEncoding.forward returns
        # input + encoding table, not the table alone).
        embedded = self.positional_encoding(input_ids)

        # NOTE(review): because `embedded` already contains `input_ids`,
        # this add yields 2*input + positional encoding — double-check that
        # the doubling is intentional.
        input_ids = input_ids + embedded

        # Flatten everything (batch included) into one vector; position is
        # already baked in, so the dense layers see a fixed-size input.
        input_ids = input_ids.view(-1)

        # Dense layer 1
        fc1_output = self.dropout1(self.fc1(input_ids))

        # Dense layer 2
        fc2_output = self.dropout2(self.fc2(fc1_output))

        # Attention head
        attention_logits = self.dropout3(self.attention(fc2_output))  # [num_digits*10]

        # Softmax over the 10 classes of each digit. dim=1 is now explicit:
        # the original relied on F.softmax's deprecated implicit-dim
        # behavior, which resolves to dim=1 for 2-D input — identical
        # result, but no deprecation warning.
        attention_softmax = F.softmax(
            attention_logits.view(self.num_digits, 10), dim=1
        )

        return  attention_softmax



